{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/property_graph/agentic_graph_rag_vertex.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Agentic GraphRAG Implementation with LlamaIndex and Vertex AI\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "------------\n",
    "### **Pre-requisites**\n",
    "- Set up a Google Cloud project\n",
    "- Create a Google Cloud storage bucket\n",
    "- Enable Vertex AI API\n",
    "\n",
    "\n",
    "#### References:\n",
    "This notebook is based on below references.\n",
    "\n",
    "1. GraphRAG Implementation with LlamaIndex\n",
    "- https://github.com/run-llama/llama_index/blob/main/docs/examples/cookbooks/GraphRAG_v1.ipynb\n",
    "\n",
    "2. Building Agentic RAG with Llamaindex Tutorial\n",
    "- https://learn.deeplearning.ai/courses/building-agentic-rag-with-llamaindex/lesson/2/router-query-engine"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_______________"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GraphRAG Aproach\n",
    "\n",
    "The GraphRAG involves two steps:\n",
    "\n",
    "1. Graph Generation - Creates Graph, builds communities and its summaries over the given document.\n",
    "2. Answer to the Query - Use summaries of the communities created from step-1 to answer the query.\n",
    "\n",
    "**Graph Generation:**\n",
    "\n",
    "1. **Source Documents to Text Chunks:** Source documents are divided into smaller text chunks for easier processing.\n",
    "\n",
    "2. **Text Chunks to Element Instances:** Each text chunk is analyzed to identify and extract entities and relationships, resulting in a list of tuples that represent these elements.\n",
    "\n",
    "3. **Element Instances to Element Summaries:** The extracted entities and relationships are summarized into descriptive text blocks for each element using the LLM.\n",
    "\n",
    "4. **Element Summaries to Graph Communities:** These entities, relationships and summaries form a graph, which is subsequently partitioned into communities using algorithms using Heirarchical Leiden to establish a hierarchical structure.\n",
    "\n",
    "5. **Graph Communities to Community Summaries:** The LLM generates summaries for each community, providing insights into the dataset’s overall topical structure and semantics.\n",
    "\n",
    "**Answering the Query:**\n",
    "\n",
    "**Community Summaries to Global Answers:** The summaries of the communities are utilized to respond to user queries. This involves generating intermediate answers, which are then consolidated into a comprehensive global answer.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GraphRAG Pipeline Components\n",
    "\n",
    "Here are the different components we implemented to build all of the processes mentioned above.\n",
    "\n",
    "1. **Source Documents to Text Chunks:** Implemented using `SentenceSplitter` with a chunk size of 1024 and chunk overlap of 20 tokens.\n",
    "\n",
    "2. **Text Chunks to Element Instances AND Element Instances to Element Summaries:** Implemented using `GraphRAGExtractor`.\n",
    "\n",
    "3. **Element Summaries to Graph Communities AND Graph Communities to Community Summaries:** Implemented using `GraphRAGStore`.\n",
    "\n",
    "4. **Community Summaries to Global Answers:** Implemented using `GraphQueryEngine`.\n",
    "\n",
    "\n",
    "Let's check into each of these components and build GraphRAG pipeline.\n"
   ]
  },
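  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the chunk size and chunk overlap settings concrete, here is a minimal sketch of fixed-size windows that share a few tokens with their neighbor. It is an illustration only: `chunk_with_overlap` is a hypothetical helper that works on whitespace-separated tokens, whereas `SentenceSplitter` uses a real tokenizer and prefers sentence boundaries.\n",
    "\n",
    "```python\n",
    "from typing import List\n",
    "\n",
    "\n",
    "def chunk_with_overlap(\n",
    "    tokens: List[str], chunk_size: int, chunk_overlap: int\n",
    ") -> List[List[str]]:\n",
    "    \"\"\"Split tokens into windows of chunk_size that share chunk_overlap tokens.\"\"\"\n",
    "    if not 0 <= chunk_overlap < chunk_size:\n",
    "        raise ValueError(\"chunk_overlap must be in [0, chunk_size)\")\n",
    "    step = chunk_size - chunk_overlap  # how far each window advances\n",
    "    chunks = []\n",
    "    for start in range(0, len(tokens), step):\n",
    "        chunks.append(tokens[start : start + chunk_size])\n",
    "        if start + chunk_size >= len(tokens):\n",
    "            break  # this window already reaches the end of the input\n",
    "    return chunks\n",
    "\n",
    "\n",
    "# Toy input: 10 single-letter tokens (hypothetical data)\n",
    "tokens = \"a b c d e f g h i j\".split()\n",
    "chunks = chunk_with_overlap(tokens, chunk_size=4, chunk_overlap=1)\n",
    "print(chunks)\n",
    "```\n",
    "\n",
    "With a chunk size of 4 and an overlap of 1, each window starts 3 tokens after the previous one, so the last token of one chunk is repeated as the first token of the next."
   ]
  },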
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "`graspologic` is used to use hierarchical_leiden for building communities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index graspologic numpy==1.24.4 scipy==1.12.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade google-cloud-aiplatform llama-index-vector-stores-vertexaivectorsearch llama-index llama_index-llms-vertex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython\n",
    "\n",
    "app = IPython.Application.instance()\n",
    "app.kernel.do_shutdown(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Authentication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "if \"google.colab\" in sys.modules:\n",
    "    from google.colab import auth\n",
    "\n",
    "    auth.authenticate_user()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Google Cloud project information and initialize Vertex AI SDK\n",
    "\n",
    "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
    "\n",
    "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROJECT_ID = \"<your project>\"  # @param {type:\"string\"}\n",
    "LOCATION = \"us-central1\"  # @param {type:\"string\"}\n",
    "\n",
    "import vertexai\n",
    "\n",
    "vertexai.init(project=PROJECT_ID, location=LOCATION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import modules needed\n",
    "from llama_index.core import (\n",
    "    StorageContext,\n",
    "    Settings,\n",
    "    VectorStoreIndex,\n",
    "    SummaryIndex,\n",
    "    SimpleDirectoryReader,\n",
    ")\n",
    "from llama_index.core.schema import TextNode\n",
    "from llama_index.core.vector_stores.types import (\n",
    "    MetadataFilters,\n",
    "    MetadataFilter,\n",
    "    FilterOperator,\n",
    ")\n",
    "from llama_index.llms.vertex import Vertex\n",
    "from llama_index.embeddings.vertex import VertexTextEmbedding\n",
    "from llama_index.vector_stores.vertexaivectorsearch import VertexAIVectorStore\n",
    "\n",
    "from typing import List, Optional\n",
    "from llama_index.core.vector_stores import FilterCondition\n",
    "from llama_index.core.tools import FunctionTool\n",
    "from llama_index.core import SimpleDirectoryReader\n",
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "\n",
    "from llama_index.core.tools import QueryEngineTool\n",
    "from llama_index.core.vector_stores import MetadataFilters\n",
    "from pathlib import Path\n",
    "\n",
    "from llama_index.core.agent.workflow import FunctionAgent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# configure embedding model\n",
    "embed_model = VertexTextEmbedding(\n",
    "    model_name=\"text-embedding-004\",\n",
    "    project=PROJECT_ID,\n",
    "    location=LOCATION,\n",
    ")\n",
    "\n",
    "vertex_gemini = Vertex(\n",
    "    model=\"gemini-1.5-pro\",\n",
    "    temperature=0,\n",
    "    context_window=100000,\n",
    "    additional_kwargs={},\n",
    ")\n",
    "\n",
    "# setup the index/query process, ie the embedding model (and completion if used)\n",
    "Settings.embed_model = embed_model\n",
    "Settings.llm = vertex_gemini"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = vertex_gemini"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare documents as required by LlamaIndex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "urls = [\n",
    "    \"https://openreview.net/pdf?id=VtmBAGCN7o\",\n",
    "    \"https://openreview.net/pdf?id=6PmJoRfdaK\",\n",
    "    \"https://openreview.net/pdf?id=LzPWWPAdY4\",\n",
    "    \"https://openreview.net/pdf?id=VTF8yNQM66\",\n",
    "    \"https://openreview.net/pdf?id=hSyW5go0v8\",\n",
    "    \"https://openreview.net/pdf?id=9WD9KwssyT\",\n",
    "    \"https://openreview.net/pdf?id=yV6fD7LYkF\",\n",
    "    \"https://openreview.net/pdf?id=hnrB5YHoYu\",\n",
    "    \"https://openreview.net/pdf?id=WbWtOYIzIK\",\n",
    "    \"https://openreview.net/pdf?id=c5pwL0Soay\",\n",
    "    \"https://openreview.net/pdf?id=TpD2aG1h0D\",\n",
    "]\n",
    "\n",
    "papers = [\n",
    "    \"metagpt.pdf\",\n",
    "    \"longlora.pdf\",\n",
    "    \"loftq.pdf\",\n",
    "    \"swebench.pdf\",\n",
    "    \"selfrag.pdf\",\n",
    "    \"zipformer.pdf\",\n",
    "    \"values.pdf\",\n",
    "    \"finetune_fair_diffusion.pdf\",\n",
    "    \"knowledge_card.pdf\",\n",
    "    \"metra.pdf\",\n",
    "    \"vr_mcl.pdf\",\n",
    "]\n",
    "import requests\n",
    "\n",
    "\n",
    "def download_file(url, file_path):\n",
    "    \"\"\"Downloads a file from a given URL and saves it to the specified file path.\n",
    "\n",
    "    Args:\n",
    "        url: The URL of the file to download.\n",
    "        file_path: The path to save the downloaded file.\n",
    "    \"\"\"\n",
    "\n",
    "    response = requests.get(url, stream=True)\n",
    "    response.raise_for_status()  # Raise an exception for non-200 status codes\n",
    "\n",
    "    with open(file_path, \"wb\") as f:\n",
    "        for chunk in response.iter_content(chunk_size=1024):\n",
    "            if chunk:  # Filter out keep-alive new chunks\n",
    "                f.write(chunk)\n",
    "\n",
    "    print(f\"Downloaded file from {url} to {file_path}\")\n",
    "\n",
    "\n",
    "for url, paper in zip(urls, papers):\n",
    "    download_file(url, paper)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GraphRAGExtractor\n",
    "\n",
    "The GraphRAGExtractor class is designed to extract triples (subject-relation-object) from text and enrich them by adding descriptions for entities and relationships to their properties using an LLM.\n",
    "\n",
    "This functionality is similar to that of the `SimpleLLMPathExtractor`, but includes additional enhancements to handle entity, relationship descriptions. For guidance on implementation, you may look at similar existing [extractors](https://docs.llamaindex.ai/en/latest/examples/property_graph/dynamic_kg_extraction/?h=comparing).\n",
    "\n",
    "Here's a breakdown of its functionality:\n",
    "\n",
    "**Key Components:**\n",
    "\n",
    "1. `llm:` The language model used for extraction.\n",
    "2. `extract_prompt:` A prompt template used to guide the LLM in extracting information.\n",
    "3. `parse_fn:` A function to parse the LLM's output into structured data.\n",
    "4. `max_paths_per_chunk:` Limits the number of triples extracted per text chunk.\n",
    "5. `num_workers:` For parallel processing of multiple text nodes.\n",
    "\n",
    "\n",
    "**Main Methods:**\n",
    "\n",
    "1. `__call__:` The entry point for processing a list of text nodes.\n",
    "2. `acall:` An asynchronous version of __call__ for improved performance.\n",
    "3. `_aextract:` The core method that processes each individual node.\n",
    "\n",
    "\n",
    "**Extraction Process:**\n",
    "\n",
    "For each input node (chunk of text):\n",
    "1. It sends the text to the LLM along with the extraction prompt.\n",
    "2. The LLM's response is parsed to extract entities, relationships, descriptions for entities and relations.\n",
    "3. Entities are converted into EntityNode objects. Entity description is stored in metadata\n",
    "4. Relationships are converted into Relation objects. Relationship description is stored in metadata.\n",
    "5. These are added to the node's metadata under KG_NODES_KEY and KG_RELATIONS_KEY.\n",
    "\n",
    "**NOTE:** In the current implementation, we are using only relationship descriptions. In the next implementation, we will utilize entity descriptions during the retrieval stage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()\n",
    "\n",
    "from typing import Any, List, Callable, Optional, Union, Dict\n",
    "from IPython.display import Markdown, display\n",
    "\n",
    "from llama_index.core.async_utils import run_jobs\n",
    "from llama_index.core.indices.property_graph.utils import (\n",
    "    default_parse_triplets_fn,\n",
    ")\n",
    "from llama_index.core.graph_stores.types import (\n",
    "    EntityNode,\n",
    "    KG_NODES_KEY,\n",
    "    KG_RELATIONS_KEY,\n",
    "    Relation,\n",
    ")\n",
    "from llama_index.core.llms.llm import LLM\n",
    "from llama_index.core.prompts import PromptTemplate\n",
    "from llama_index.core.prompts.default_prompts import (\n",
    "    DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,\n",
    ")\n",
    "from llama_index.core.schema import TransformComponent, BaseNode\n",
    "from llama_index.core.bridge.pydantic import BaseModel, Field\n",
    "\n",
    "\n",
    "class GraphRAGExtractor(TransformComponent):\n",
    "    \"\"\"Extract triples from a graph.\n",
    "\n",
    "    Uses an LLM and a simple prompt + output parsing to extract paths (i.e. triples) and entity, relation descriptions from text.\n",
    "\n",
    "    Args:\n",
    "        llm (LLM):\n",
    "            The language model to use.\n",
    "        extract_prompt (Union[str, PromptTemplate]):\n",
    "            The prompt to use for extracting triples.\n",
    "        parse_fn (callable):\n",
    "            A function to parse the output of the language model.\n",
    "        num_workers (int):\n",
    "            The number of workers to use for parallel processing.\n",
    "        max_paths_per_chunk (int):\n",
    "            The maximum number of paths to extract per chunk.\n",
    "    \"\"\"\n",
    "\n",
    "    llm: LLM\n",
    "    extract_prompt: PromptTemplate\n",
    "    parse_fn: Callable\n",
    "    num_workers: int\n",
    "    max_paths_per_chunk: int\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        llm: Optional[LLM] = None,\n",
    "        extract_prompt: Optional[Union[str, PromptTemplate]] = None,\n",
    "        parse_fn: Callable = default_parse_triplets_fn,\n",
    "        max_paths_per_chunk: int = 10,\n",
    "        num_workers: int = 4,\n",
    "    ) -> None:\n",
    "        \"\"\"Init params.\"\"\"\n",
    "        from llama_index.core import Settings\n",
    "\n",
    "        if isinstance(extract_prompt, str):\n",
    "            extract_prompt = PromptTemplate(extract_prompt)\n",
    "\n",
    "        super().__init__(\n",
    "            llm=llm or Settings.llm,\n",
    "            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,\n",
    "            parse_fn=parse_fn,\n",
    "            num_workers=num_workers,\n",
    "            max_paths_per_chunk=max_paths_per_chunk,\n",
    "        )\n",
    "\n",
    "    @classmethod\n",
    "    def class_name(cls) -> str:\n",
    "        return \"GraphExtractor\"\n",
    "\n",
    "    def __call__(\n",
    "        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any\n",
    "    ) -> List[BaseNode]:\n",
    "        \"\"\"Extract triples from nodes.\"\"\"\n",
    "        return asyncio.run(\n",
    "            self.acall(nodes, show_progress=show_progress, **kwargs)\n",
    "        )\n",
    "\n",
    "    async def _aextract(self, node: BaseNode) -> BaseNode:\n",
    "        \"\"\"Extract triples from a node.\"\"\"\n",
    "        assert hasattr(node, \"text\")\n",
    "\n",
    "        text = node.get_content(metadata_mode=\"llm\")\n",
    "        try:\n",
    "            llm_response = await self.llm.apredict(\n",
    "                self.extract_prompt,\n",
    "                text=text,\n",
    "                max_knowledge_triplets=self.max_paths_per_chunk,\n",
    "            )\n",
    "            entities, entities_relationship = self.parse_fn(llm_response)\n",
    "        except ValueError:\n",
    "            entities = []\n",
    "            entities_relationship = []\n",
    "\n",
    "        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])\n",
    "        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])\n",
    "        metadata = node.metadata.copy()\n",
    "        for entity, entity_type, description in entities:\n",
    "            metadata[\n",
    "                \"entity_description\"\n",
    "            ] = description  # Not used in the current implementation. But will be useful in future work.\n",
    "            entity_node = EntityNode(\n",
    "                name=entity, label=entity_type, properties=metadata\n",
    "            )\n",
    "            existing_nodes.append(entity_node)\n",
    "\n",
    "        metadata = node.metadata.copy()\n",
    "        for triple in entities_relationship:\n",
    "            subj, rel, obj, description = triple\n",
    "            subj_node = EntityNode(name=subj, properties=metadata)\n",
    "            obj_node = EntityNode(name=obj, properties=metadata)\n",
    "            metadata[\"relationship_description\"] = description\n",
    "            rel_node = Relation(\n",
    "                label=rel,\n",
    "                source_id=subj_node.id,\n",
    "                target_id=obj_node.id,\n",
    "                properties=metadata,\n",
    "            )\n",
    "\n",
    "            existing_nodes.extend([subj_node, obj_node])\n",
    "            existing_relations.append(rel_node)\n",
    "\n",
    "        node.metadata[KG_NODES_KEY] = existing_nodes\n",
    "        node.metadata[KG_RELATIONS_KEY] = existing_relations\n",
    "        return node\n",
    "\n",
    "    async def acall(\n",
    "        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any\n",
    "    ) -> List[BaseNode]:\n",
    "        \"\"\"Extract triples from nodes async.\"\"\"\n",
    "        jobs = []\n",
    "        for node in nodes:\n",
    "            jobs.append(self._aextract(node))\n",
    "\n",
    "        return await run_jobs(\n",
    "            jobs,\n",
    "            workers=self.num_workers,\n",
    "            show_progress=show_progress,\n",
    "            desc=\"Extracting paths from text\",\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GraphRAGStore\n",
    "\n",
    "The `GraphRAGStore` class is an extension of the `SimplePropertyGraphStore `class, designed to implement GraphRAG pipeline. Here's a breakdown of its key components and functions:\n",
    "\n",
    "\n",
    "The class uses community detection algorithms to group related nodes in the graph and then it generates summaries for each community using an LLM.\n",
    "\n",
    "\n",
    "**Key Methods:**\n",
    "\n",
    "`build_communities():`\n",
    "\n",
    "1. Converts the internal graph representation to a NetworkX graph.\n",
    "\n",
    "2. Applies the hierarchical Leiden algorithm for community detection.\n",
    "\n",
    "3. Collects detailed information about each community.\n",
    "\n",
    "4. Generates summaries for each community.\n",
    "\n",
    "`generate_community_summary(text):`\n",
    "\n",
    "1. Uses LLM to generate a summary of the relationships in a community.\n",
    "2. The summary includes entity names and a synthesis of relationship descriptions.\n",
    "\n",
    "`_create_nx_graph():`\n",
    "\n",
    "1. Converts the internal graph representation to a NetworkX graph for community detection.\n",
    "\n",
    "`_collect_community_info(nx_graph, clusters):`\n",
    "\n",
    "1. Collects detailed information about each node based on its community.\n",
    "2. Creates a string representation of each relationship within a community.\n",
    "\n",
    "`_summarize_communities(community_info):`\n",
    "\n",
    "1. Generates and stores summaries for each community using LLM.\n",
    "\n",
    "`get_community_summaries():`\n",
    "\n",
    "1. Returns the community summaries by building them if not already done."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from llama_index.core.graph_stores import SimplePropertyGraphStore\n",
    "import networkx as nx\n",
    "from graspologic.partition import hierarchical_leiden\n",
    "\n",
    "from llama_index.core.llms import ChatMessage\n",
    "\n",
    "\n",
    "class GraphRAGStore(SimplePropertyGraphStore):\n",
    "    community_summary = {}\n",
    "    max_cluster_size = 5\n",
    "\n",
    "    def generate_community_summary(self, text):\n",
    "        \"\"\"Generate summary for a given text using an LLM.\"\"\"\n",
    "        messages = [\n",
    "            ChatMessage(\n",
    "                role=\"system\",\n",
    "                content=(\n",
    "                    \"You are provided with a set of relationships from a knowledge graph, each represented as \"\n",
    "                    \"entity1->entity2->relation->relationship_description. Your task is to create a summary of these \"\n",
    "                    \"relationships. The summary should include the names of the entities involved and a concise synthesis \"\n",
    "                    \"of the relationship descriptions. The goal is to capture the most critical and relevant details that \"\n",
    "                    \"highlight the nature and significance of each relationship. Ensure that the summary is coherent and \"\n",
    "                    \"integrates the information in a way that emphasizes the key aspects of the relationships.\"\n",
    "                ),\n",
    "            ),\n",
    "            ChatMessage(role=\"user\", content=text),\n",
    "        ]\n",
    "        response = llm.chat(messages)\n",
    "        clean_response = re.sub(r\"^assistant:\\s*\", \"\", str(response)).strip()\n",
    "        return clean_response\n",
    "\n",
    "    def build_communities(self):\n",
    "        \"\"\"Builds communities from the graph and summarizes them.\"\"\"\n",
    "        nx_graph = self._create_nx_graph()\n",
    "        community_hierarchical_clusters = hierarchical_leiden(\n",
    "            nx_graph, max_cluster_size=self.max_cluster_size\n",
    "        )\n",
    "        community_info = self._collect_community_info(\n",
    "            nx_graph, community_hierarchical_clusters\n",
    "        )\n",
    "        self._summarize_communities(community_info)\n",
    "\n",
    "    def _create_nx_graph(self):\n",
    "        \"\"\"Converts internal graph representation to NetworkX graph.\"\"\"\n",
    "        nx_graph = nx.Graph()\n",
    "        for node in self.graph.nodes.values():\n",
    "            nx_graph.add_node(str(node))\n",
    "        for relation in self.graph.relations.values():\n",
    "            nx_graph.add_edge(\n",
    "                relation.source_id,\n",
    "                relation.target_id,\n",
    "                relationship=relation.label,\n",
    "                description=relation.properties[\"relationship_description\"],\n",
    "            )\n",
    "        return nx_graph\n",
    "\n",
    "    def _collect_community_info(self, nx_graph, clusters):\n",
    "        \"\"\"Collect detailed information for each node based on their community.\"\"\"\n",
    "        community_mapping = {item.node: item.cluster for item in clusters}\n",
    "        community_info = {}\n",
    "        for item in clusters:\n",
    "            cluster_id = item.cluster\n",
    "            node = item.node\n",
    "            if cluster_id not in community_info:\n",
    "                community_info[cluster_id] = []\n",
    "\n",
    "            for neighbor in nx_graph.neighbors(node):\n",
    "                if community_mapping[neighbor] == cluster_id:\n",
    "                    edge_data = nx_graph.get_edge_data(node, neighbor)\n",
    "                    if edge_data:\n",
    "                        detail = f\"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}\"\n",
    "                        community_info[cluster_id].append(detail)\n",
    "        return community_info\n",
    "\n",
    "    def _summarize_communities(self, community_info):\n",
    "        \"\"\"Generate and store summaries for each community.\"\"\"\n",
    "        for community_id, details in community_info.items():\n",
    "            details_text = (\n",
    "                \"\\n\".join(details) + \".\"\n",
    "            )  # Ensure it ends with a period\n",
    "            self.community_summary[\n",
    "                community_id\n",
    "            ] = self.generate_community_summary(details_text)\n",
    "\n",
    "    def get_community_summaries(self):\n",
    "        \"\"\"Returns the community summaries, building them if not already done.\"\"\"\n",
    "        if not self.community_summary:\n",
    "            self.build_communities()\n",
    "        return self.community_summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GraphRAGQueryEngine\n",
    "\n",
    "The GraphRAGQueryEngine class is a custom query engine designed to process queries using the GraphRAG approach. It leverages the community summaries generated by the GraphRAGStore to answer user queries. Here's a breakdown of its functionality:\n",
    "\n",
    "**Main Components:**\n",
    "\n",
    "`graph_store:` An instance of GraphRAGStore, which contains the community summaries.\n",
    "`llm:` A Language Model (LLM) used for generating and aggregating answers.\n",
    "\n",
    "\n",
    "**Key Methods:**\n",
    "\n",
    "`custom_query(query_str: str)`\n",
    "\n",
    "1. This is the main entry point for processing a query. It retrieves community summaries, generates answers from each summary, and then aggregates these answers into a final response.\n",
    "\n",
    "`generate_answer_from_summary(community_summary, query):`\n",
    "\n",
    "1. Generates an answer for the query based on a single community summary.\n",
    "Uses the LLM to interpret the community summary in the context of the query.\n",
    "\n",
    "`aggregate_answers(community_answers):`\n",
    "\n",
    "1. Combines individual answers from different communities into a coherent final response.\n",
    "2. Uses the LLM to synthesize multiple perspectives into a single, concise answer.\n",
    "\n",
    "\n",
    "**Query Processing Flow:**\n",
    "\n",
    "1. Retrieve community summaries from the graph store.\n",
    "2. For each community summary, generate a specific answer to the query.\n",
    "3. Aggregate all community-specific answers into a final, coherent response.\n",
    "\n",
    "\n",
    "**Example usage:**\n",
    "\n",
    "```\n",
    "query_engine = GraphRAGQueryEngine(graph_store=graph_store, llm=llm)\n",
    "\n",
    "response = query_engine.query(\"query\")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.query_engine import CustomQueryEngine\n",
    "from llama_index.core.llms import LLM\n",
    "\n",
    "\n",
    "class GraphRAGQueryEngine(CustomQueryEngine):\n",
    "    graph_store: GraphRAGStore\n",
    "    llm: LLM\n",
    "\n",
    "    def custom_query(self, query_str: str) -> str:\n",
    "        \"\"\"Process all community summaries to generate answers to a specific query.\"\"\"\n",
    "        community_summaries = self.graph_store.get_community_summaries()\n",
    "        community_answers = [\n",
    "            self.generate_answer_from_summary(community_summary, query_str)\n",
    "            for _, community_summary in community_summaries.items()\n",
    "        ]\n",
    "\n",
    "        final_answer = self.aggregate_answers(community_answers)\n",
    "        return final_answer\n",
    "\n",
    "    def generate_answer_from_summary(self, community_summary, query):\n",
    "        \"\"\"Generate an answer from a community summary based on a given query using LLM.\"\"\"\n",
    "        prompt = (\n",
    "            f\"Given the community summary: {community_summary}, \"\n",
    "            f\"how would you answer the following query? Query: {query}\"\n",
    "        )\n",
    "        messages = [\n",
    "            ChatMessage(role=\"system\", content=prompt),\n",
    "            ChatMessage(\n",
    "                role=\"user\",\n",
    "                content=\"I need an answer based on the above information.\",\n",
    "            ),\n",
    "        ]\n",
    "        response = self.llm.chat(messages)\n",
    "        cleaned_response = re.sub(r\"^assistant:\\s*\", \"\", str(response)).strip()\n",
    "        return cleaned_response\n",
    "\n",
    "    def aggregate_answers(self, community_answers):\n",
    "        \"\"\"Aggregate individual community answers into a final, coherent response.\"\"\"\n",
    "        # intermediate_text = \" \".join(community_answers)\n",
    "        prompt = \"Combine the following intermediate answers into a final, concise response.\"\n",
    "        messages = [\n",
    "            ChatMessage(role=\"system\", content=prompt),\n",
    "            ChatMessage(\n",
    "                role=\"user\",\n",
    "                content=f\"Intermediate answers: {community_answers}\",\n",
    "            ),\n",
    "        ]\n",
    "        final_response = self.llm.chat(messages)\n",
    "        cleaned_final_response = re.sub(\n",
    "            r\"^assistant:\\s*\", \"\", str(final_response)\n",
    "        ).strip()\n",
    "        return cleaned_final_response"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  Build End to End GraphRAG Pipeline\n",
    "\n",
    "Now that we have defined all the necessary components, let’s construct the GraphRAG pipeline:\n",
    "\n",
    "1. Create nodes/chunks from the text.\n",
    "2. Build a PropertyGraphIndex using `GraphRAGExtractor` and `GraphRAGStore`.\n",
    "3. Construct communities and generate a summary for each community using the graph built above.\n",
    "4. Create a `GraphRAGQueryEngine` and begin querying."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load documents\n",
    "documents = SimpleDirectoryReader(input_files=[\"metagpt.pdf\"]).load_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create nodes/ chunks from the text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "\n",
    "splitter = SentenceSplitter(\n",
    "    chunk_size=1024,\n",
    "    chunk_overlap=20,\n",
    ")\n",
    "nodes = splitter.get_nodes_from_documents(documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(nodes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build PropertyGraphIndex using `GraphRAGExtractor` and `GraphRAGStore`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "KG_TRIPLET_EXTRACT_TMPL = \"\"\"\n",
    "-Goal-\n",
    "Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.\n",
    "Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.\n",
    "\n",
    "-Steps-\n",
    "1. Identify all entities. For each identified entity, extract the following information:\n",
    "- entity_name: Name of the entity, capitalized\n",
    "- entity_type: Type of the entity\n",
    "- entity_description: Comprehensive description of the entity's attributes and activities\n",
    "Format each entity as (\"entity\"$$$$<entity_name>$$$$<entity_type>$$$$<entity_description>)\n",
    "\n",
    "2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.\n",
    "For each pair of related entities, extract the following information:\n",
    "- source_entity: name of the source entity, as identified in step 1\n",
    "- target_entity: name of the target entity, as identified in step 1\n",
    "- relation: relationship between source_entity and target_entity\n",
    "- relationship_description: explanation as to why you think the source entity and the target entity are related to each other\n",
    "\n",
    "Format each relationship as (\"relationship\"$$$$<source_entity>$$$$<target_entity>$$$$<relation>$$$$<relationship_description>)\n",
    "\n",
    "3. When finished, output.\n",
    "\n",
    "-Real Data-\n",
    "######################\n",
    "text: {text}\n",
    "######################\n",
    "output:\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_pattern = r'\\(\"entity\"\\$\\$\\$\\$\"(.+?)\"\\$\\$\\$\\$\"(.+?)\"\\$\\$\\$\\$\"(.+?)\"\\)'\n",
    "relationship_pattern = r'\\(\"relationship\"\\$\\$\\$\\$\"(.+?)\"\\$\\$\\$\\$\"(.+?)\"\\$\\$\\$\\$\"(.+?)\"\\$\\$\\$\\$\"(.+?)\"\\)'\n",
    "\n",
    "\n",
    "def parse_fn(response_str: str) -> Any:\n",
    "    entities = re.findall(entity_pattern, response_str)\n",
    "    relationships = re.findall(relationship_pattern, response_str)\n",
    "    return entities, relationships\n",
    "\n",
    "\n",
    "kg_extractor = GraphRAGExtractor(\n",
    "    llm=llm,\n",
    "    extract_prompt=KG_TRIPLET_EXTRACT_TMPL,\n",
    "    max_paths_per_chunk=2,\n",
    "    parse_fn=parse_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import PropertyGraphIndex\n",
    "\n",
    "index = PropertyGraphIndex(\n",
    "    nodes=nodes,\n",
    "    property_graph_store=GraphRAGStore(),\n",
    "    kg_extractors=[kg_extractor],\n",
    "    show_progress=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(index.property_graph_store.graph.nodes.values())[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(index.property_graph_store.graph.relations.values())[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(index.property_graph_store.graph.relations.values())[0].properties[\n",
    "    \"relationship_description\"\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build communities\n",
    "\n",
    "This will create communities and summary for each community."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index.property_graph_store.build_communities()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create QueryEngine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query_engine = GraphRAGQueryEngine(\n",
    "    graph_store=index.property_graph_store, llm=llm\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Querying"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = query_engine.query(\"What are the main discussed in the document?\")\n",
    "display(Markdown(f\"{response.response}\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building an Agent Reasoning Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_doc_tools(\n",
    "    file_path: str,\n",
    "    name: str,\n",
    ") -> str:\n",
    "    \"\"\"Get vector query and summary query tools from a document.\"\"\"\n",
    "\n",
    "    # load documents\n",
    "    documents = SimpleDirectoryReader(input_files=[file_path]).load_data()\n",
    "    splitter = SentenceSplitter(chunk_size=1024)\n",
    "    nodes = splitter.get_nodes_from_documents(documents)\n",
    "\n",
    "    index = PropertyGraphIndex(\n",
    "        nodes=nodes,\n",
    "        property_graph_store=GraphRAGStore(),\n",
    "        kg_extractors=[kg_extractor],\n",
    "        show_progress=True,\n",
    "    )\n",
    "    index.property_graph_store.build_communities()\n",
    "    query_engine = GraphRAGQueryEngine(\n",
    "        graph_store=index.property_graph_store, llm=llm\n",
    "    )\n",
    "\n",
    "    summary_index = SummaryIndex(nodes)\n",
    "\n",
    "    def vector_query(\n",
    "        query: str, page_numbers: Optional[List[str]] = None\n",
    "    ) -> str:\n",
    "        \"\"\"Use to answer questions over the MetaGPT paper.\n",
    "\n",
    "        Useful if you have specific questions over the MetaGPT paper.\n",
    "        Always leave page_numbers as None UNLESS there is a specific page you want to search for.\n",
    "\n",
    "        Args:\n",
    "            query (str): the string query to be embedded.\n",
    "            page_numbers (Optional[List[str]]): Filter by set of pages. Leave as NONE\n",
    "                if we want to perform a vector search\n",
    "                over all pages. Otherwise, filter by the set of specified pages.\n",
    "\n",
    "        \"\"\"\n",
    "        vector_query_engine = query_engine\n",
    "        response = vector_query_engine.query(query)\n",
    "        return response\n",
    "\n",
    "    vector_query_tool = FunctionTool.from_defaults(\n",
    "        name=f\"vector_tool_{name}\", fn=vector_query\n",
    "    )\n",
    "\n",
    "    def summary_query(\n",
    "        query: str,\n",
    "    ) -> str:\n",
    "        \"\"\"Perform a summary of document\n",
    "        query (str): the string query to be embedded.\n",
    "        \"\"\"\n",
    "        summary_engine = summary_index.as_query_engine(\n",
    "            response_mode=\"tree_summarize\",\n",
    "            use_async=True,\n",
    "        )\n",
    "\n",
    "        response = summary_engine.query(query)\n",
    "        return response\n",
    "\n",
    "    summary_tool = FunctionTool.from_defaults(\n",
    "        fn=summary_query, name=f\"summary_tool_{name}\"\n",
    "    )\n",
    "\n",
    "    return vector_query_tool, summary_tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_query_tool, summary_tool = get_doc_tools(\"metagpt.pdf\", \"metagpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create Agent\n",
    "agent = FunctionAgent(\n",
    "    tools=[summary_tool, vector_query_tool],\n",
    "    llm=vertex_gemini,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = await agent.run(\n",
    "    \"what are agent roles in MetaGPT, \"\n",
    "    \"and then how they communicate with each other.\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "llamaindex",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
